{
 "metadata": {
  "name": "",
  "signature": "sha256:e5c488b0f45c26267e6b3f2f6237c97d2e01468726aaa0f2f41fcc3920ca448e"
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import numpy as np\n",
      "from matplotlib import pyplot as plt\n",
      "%matplotlib inline"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 1
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def sigmoid(z):\n",
      "    return 1 / (1 + np.exp(-z))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 2
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def compute_loss(a, y):\n",
      "    return -(y*np.log(a) + (1-y)*np.log(1-a))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 3
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def train(data, label, epochs, lr):\n",
      "    num = len(data)\n",
      "    dim = data.shape[1]\n",
      "    gradWs = np.zeros(dim)\n",
      "    gradBs = 0\n",
      "    losses = []\n",
      "    reg_rate = 0.001  # \u6b63\u5219\u9879\u7cfb\u6570\n",
      "    weight = np.zeros(dim)\n",
      "    bias = 0\n",
      "    for i in range(epochs):\n",
      "        loss = 0\n",
      "        gradW = np.zeros(dim)\n",
      "        gradB = 0\n",
      "        for j in range(num):\n",
      "            z = np.dot(data[j], weight) + bias\n",
      "            a = sigmoid(z)\n",
      "            loss += compute_loss(a, label[j])\n",
      "            #for k in range(dim):\n",
      "            #    gradW[k] += (-1) * (label[j] - a) * data[j, k] + 2 * reg_rate * gradW[k]\n",
      "            gradW += data[j]*(a-label[j])\n",
      "            gradB += (a-label[j])\n",
      "        losses.append(loss/num)\n",
      "        gradWs += gradW**2\n",
      "        gradBs += gradB**2\n",
      "        weight -= lr * gradW/(gradWs**0.5)\n",
      "        bias -= lr * gradB/(gradBs**0.5)\n",
      "        \n",
      "        if i % 3 == 0:\n",
      "            # loss = 0\n",
      "            acc = 0\n",
      "            result = np.zeros(num)\n",
      "            for j in range(num):\n",
      "                y_pre = weight.dot(data[j, :]) + bias\n",
      "                a = 1 / (1 + np.exp(-y_pre))\n",
      "                if a >= 0.5:\n",
      "                    result[j] = 1\n",
      "                else:\n",
      "                    result[j] = 0\n",
      "\n",
      "                if result[j] == label[j]:\n",
      "                    acc += 1.0\n",
      "                loss += (-1) * (label[j] * np.log(a) + (1 - label[j]) * np.log(1 - a))\n",
      "            print('after {} epochs, the loss on train data is:'.format(i), loss / num)\n",
      "            print('after {} epochs, the acc on train data is:'.format(i), acc / num)\n",
      "    print (losses)\n",
      "    plt.plot(losses)\n",
      "    plt.show\n",
      "    return weight, bias\n",
      "    "
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 4
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def train_mat(data, label, epochs, lr):\n",
      "    num, dim = data.shape\n",
      "    weight = np.zeros(dim+1)\n",
      "    data = np.concatenate((np.ones((num, 1)), data), axis=1)\n",
      "    gradS = np.zeros(dim+1)\n",
      "    data_t = data.transpose()\n",
      "    for i in range(epochs):\n",
      "        z = np.dot(data, weight)\n",
      "        a = sigmoid(z)\n",
      "        grad = np.dot(data_t, (a - label))\n",
      "        gradS += grad**2\n",
      "        weight -= lr * grad/(gradS**0.5)\n",
      "        if i%3 == 0:\n",
      "            z = np.dot(data, weight)\n",
      "            a = sigmoid(z)\n",
      "            acc = len(np.where(a>0.5)[0])\n",
      "            print (acc)\n",
      "            print('after {} epochs, the acc on train data is:'.format(i), acc / num)\n",
      "    return weight, 0"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 5
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "data = np.load('train_data.npy')\n",
      "label = np.load('train_label.npy')\n",
      "epochs = 30\n",
      "lr = 1\n",
      "weight, bias = train(data, label, epochs, lr)\n",
      "#weight, bias = train_mat(data, label, epochs, lr)\n",
      "print (weight)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "after 0 epochs, the loss on train data is: nan\n",
        "after 0 epochs, the acc on train data is: 0.6671428571428571\n",
        "after 3 epochs, the loss on train data is:"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        " nan\n",
        "after 3 epochs, the acc on train data is: 0.8925714285714286\n",
        "after 6 epochs, the loss on train data is:"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        " nan\n",
        "after 6 epochs, the acc on train data is: 0.9174285714285715\n",
        "after 9 epochs, the loss on train data is:"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        " nan\n",
        "after 9 epochs, the acc on train data is: 0.9171428571428571\n",
        "after 12 epochs, the loss on train data is:"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        " nan\n",
        "after 12 epochs, the acc on train data is: 0.9242857142857143\n",
        "after 15 epochs, the loss on train data is:"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        " nan\n",
        "after 15 epochs, the acc on train data is: 0.9171428571428571\n",
        "after 18 epochs, the loss on train data is:"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        " nan\n",
        "after 18 epochs, the acc on train data is: 0.88\n",
        "after 21 epochs, the loss on train data is:"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        " nan\n",
        "after 21 epochs, the acc on train data is: 0.9117142857142857\n",
        "after 24 epochs, the loss on train data is:"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        " nan\n",
        "after 24 epochs, the acc on train data is: 0.9245714285714286\n",
        "after 27 epochs, the loss on train data is:"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        " nan\n",
        "after 27 epochs, the acc on train data is: 0.92\n",
        "[0.69314718055993774, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "[-0.38126145 -0.20966912  0.22632263  1.03109394  0.50508149  1.12808021\n",
        "  2.15981479  0.46267897  0.84314915  0.09434013 -0.38605399 -0.21365656\n",
        "  0.07157737  0.02383779  1.00486464  0.97756852  1.27711304  0.20648427\n",
        "  0.09837524  1.31668992  0.21858887  0.29521341  2.35542328  0.48940357\n",
        " -1.45354259 -1.2511838  -1.46798289  0.58280617 -1.45029877 -0.66465604\n",
        " -1.54475628 -1.12309543 -1.00240312 -0.8543639  -1.43865152  0.72599227\n",
        " -0.04042505 -0.25982353 -1.18106951 -1.25467656 -1.94117793 -1.77684932\n",
        " -1.16421411 -1.30797276 -0.86707952 -1.4741805  -1.76720885 -1.85117307\n",
        " -1.65366526 -0.96339721 -1.8319499   0.29240368  3.00946295  1.64257379\n",
        "  0.06520554  0.25511804  0.1709898 ]\n"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stderr",
       "text": [
        "-c:42: RuntimeWarning: divide by zero encountered in log\n",
        "-c:42: RuntimeWarning: invalid value encountered in double_scalars\n",
        "-c:2: RuntimeWarning: divide by zero encountered in log\n",
        "-c:2: RuntimeWarning: invalid value encountered in double_scalars\n"
       ]
      },
      {
       "metadata": {},
       "output_type": "display_data",
       "png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAEt5JREFUeJzt3X+s3fd91/Hna3adLqogznxTOv9I\nXHDUtbRKpjOjLZSlG24NiKSMKXMktGwTNRKkfxS1wlGRGB6T1pYpMLA0eVOlbVLiZYFGF6biZm1K\nS0nAx1vaxg52bh3A1y2LlzqMLFodp2/+OF9rX99e555777k+dj/Ph3Tk8/18399z329bep2vv99z\n7FQVkqQ2fM+0G5AkXTmGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakh68cpSrIb\n+NfAOuDXq+qXFux/EHhPt3k9cFNV3ZDkZuBTjN5c3gD8m6r61df7WZs2bapbbrllWUNIUuuOHj36\nR1U1s1RdlvpnGJKsA04Cu4B54Ahwb1Udv0z9B4Hbq+rnkmzofsa3krwJeAb4kar6+uV+3mAwqOFw\nuFTfkqSeJEerarBU3TiXd3YCc1V1qqrOA4eAu1+n/l7gYYCqOl9V3+rWrxvz50mS1sg4IbwZON3b\nnu/WvkN3OWc78Lne2tYkX+le42Ovd5YvSVpbkz7z3gM8WlWvXVyoqtNV9S7gLwH3JXnzwoOS7E0y\nTDI8e/bshFuSJF00TuifAbb2trd0a4vZQ3dpZ6HuDP8Z4N2L7DtYVYOqGszMLHkfQpK0QuOE/hFg\nR5Lt3Y3ZPcDswqIkbwM2Ak/21rYk+d7u+UbgrwInJtG4JGn5lvzIZlVdSHI/cJjRRzY/WVXHkuwH\nhlV18Q1gD3CoLv040A8Av5ykgAD/sqq+OtkRJEnjWvIjm1eaH9mUpOWb5Ec2JUnfJQx9SWqIoS9J\nDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQ\nQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSFjhX6S3UlOJJlLsm+R/Q8mebp7nEzyUrd+W5InkxxL8pUk\nPzXpASRJ41u/VEGSdcABYBcwDxxJMltVxy/WVNWHevUfBG7vNl8Bfrqqnkvy/cDRJIer6qVJDiFJ\nGs84Z/o7gbmqOlVV54FDwN2vU38v8DBAVZ2sque6518HXgBmVteyJGmlxgn9zcDp3vZ8t/YdktwM\nbAc+t8i+ncAG4GvLb1OSNAmTvpG7B3i0ql7rLyZ5C/BbwM9W1bcXHpRkb5JhkuHZs2cn3JIk6aJx\nQv8MsLW3vaVbW8weuks7FyX5c8DvAh+tqqcWO6iqDlbVoKoGMzNe/ZGktTJO6B8BdiTZnmQDo2Cf\nXViU5G3ARuDJ3toG4FPAb1bVo5NpWZK0UkuGflVdAO4HDgPPAo9U1bEk+5Pc1SvdAxyqquqt3QP8\nNeBneh/pvG2C/UuSliGXZvT0DQaDGg6H025Dkq4pSY5W1WCpOr+RK0kNMfQlqSGGviQ1xNCXpIYY\n+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEv\nSQ0x9CWpIYa+JDVkrNBPsjvJiSRzSfYtsv/B3n98fjLJS719/ynJS0n+4yQblyQt3/qlCpKsAw4A\nu4B54EiS2ao6frGmqj7Uq/8gcHvvJT4BXA/8g0k1LUlamXHO9HcCc1V1qqrOA4eAu1+n/l7g4Ysb\nVfVZ4P+tqktJ0kSME/qbgdO97flu7TskuRnYDnxu9a1JkiZt0jdy9wCPVtVryzkoyd4kwyTDs2fP\nTrglSdJF44T+GWBrb3tLt7aYPfQu7Yyrqg5W1aCqBjMzM8s9XJI0pnFC/wiwI8n2JBsYBfvswqIk\nbwM2Ak9OtkVJ0qQsGfpVdQG4HzgMPAs8UlXHkuxPclevdA9wqKqqf3ySLwK/A/x4kvkk75tc+5Kk\n5ciCjJ66wWBQw+Fw2m1I0jUlydGqGixV5zdyJakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlq\niKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY\n+pLUkLFCP8nuJCeSzCXZt8j+B5M83T1OJnmpt+++JM91j/sm2bwkaXnWL1WQZB1wANgFzANHksxW\n1fGLNVX1oV79B4Hbu+c3Av8MGAAFHO2OPTfRKSRJYxnnTH8nMFdVp6rqPHAIuPt16u8FHu6evw94\nvKq+2QX948Du1TQsSVq5cUJ/M3C6tz3frX2HJDcD24HPLfdYSdLam/SN3D3Ao1X12nIOSrI3yTDJ\n8OzZsxNuSZJ00TihfwbY2tve0q0tZg9/dmln7GOr6mBVDapqMDMzM0ZLkqSVGCf0jwA7kmxPsoFR\nsM8uLEryNmAj8GRv+TDw3iQbk2wE3tutSZKmYMlP71TVhST3MwrrdcAnq+pYkv3AsKouvgHsAQ5V\nVfWO/WaSX2D0xgGwv6q+OdkRJEnjSi+jrwqDwaCGw+G025Cka0qSo1U1WKrOb+RKUkMMfUlqiKEv\nSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLU\nEENfkhpi6EtSQwx9SWqIoS9JDRkr9JPsTnIiyVySfZepuSfJ8STHkjzUW/9Ykme6x09NqnFJ0vKt\nX6ogyTrgALALmAeOJJmtquO9mh3AA8AdVXUuyU3d+t8CfhC4DbgO+HyST1fVH09+FEnSUsY5098J\nzFXVqao6DxwC7l5Q8wHgQFWdA6iqF7r1twNfqKoLVfUnwFeA3ZNpXZK0XOOE/mbgdG97vlvruxW4\nNcmXkjyV5GKwfxnYneT6JJuA9wBbV9u0JGlllry8s4zX2QHcCWwBvpDknVX1mSQ/BPxX4CzwJPDa\nwoOT7AX2Amzbtm1CLUmSFhrnTP8Ml56db+nW+uaB2ap6taqeB04yehOgqn6xqm6rql1Aun2XqKqD\nVTWoqsHMzMxK5pAkjWGc0D8C7EiyPckGYA8wu6DmMUZn+XSXcW4FTiVZl+T7uvV3Ae8CPjOh3iVJ\ny7Tk5Z2qupDkfuAwsA74ZFUdS7IfGFbVbLfvvUmOM7p885GqejHJG4EvJgH4Y+DvVdWFtRpGkvT6\nUlXT7uESg8GghsPhtNuQpGtKkqNVNViqzm/kSlJDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLU\nEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x\n9CWpIWOFfpLdSU4kmUuy7zI19yQ5nuRYkod66x/v1p5N8itJMqnmJUnLs36pgiTrgAPALmAeOJJk\ntqqO92p2AA8Ad1TVuSQ3des/AtwBvKsr/S/AjwKfn+QQkqTxjHOmvxOYq6pTVXUeOATcvaDmA8CB\nqjoHUFUvdOsFvBHYAFwHvAH4w0k0LklavnFCfzNwurc936313QrcmuRLSZ5Kshugqp4EngC+0T0O\nV9Wzq29bkrQSS17eWcbr7ADuBLYAX0jyTmAT8APdGsDjSd5dVV/sH5xkL7AXYNu2bRNqSZK00Dhn\n+meArb3tLd1a3zwwW1WvVtXzwElGbwJ/B3iqql6uqpeBTwM/vPAHVNXBqhpU1WBmZmYlc0iSxjBO\n6B8BdiTZnmQDsAeYXVDzGKOzfJJsYnS55xTwv4EfTbI+yRsY3cT18o4kTcmSoV9VF4D7gcOMAvuR\nqjqWZH+Su7qyw8CLSY4zuob/kap6EXgU+BrwVeDLwJer6j+swRySpDGkqqbdwyUGg0ENh8NptyFJ\n15QkR6tqsFSd38iVpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBD\nX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGjJW6CfZneREkrkk+y5Tc0+S\n40mOJXmoW3tPkqd7jz9N8v5JDiBJGt/6pQqSrAMOALuAeeBIktmqOt6r2QE8ANxRVeeS3ARQVU8A\nt3U1NwJzwGcmPoUkaSzjnOnvBOaq6lRVnQcOAXcvqPkAcKCqzgFU1QuLvM5PAp+uqldW07AkaeXG\nCf3NwOne9ny31ncrcGuSLyV5KsnuRV5nD/DwytqUJE3Ckpd3lvE6O4A7gS3AF5K8s6peAkjyFuCd\nwOHFDk6yF9gLsG3btgm1JElaaJwz/TPA1t72lm6tbx6YrapXq+p54CSjN4GL7gE+VVWvLvYDqupg\nVQ2qajAzMzN+95KkZRkn9I8AO5JsT7KB0WWa2QU1jzE6yyfJJkaXe0719t+Ll3YkaeqWDP2qugDc\nz+jSzLPAI1V1LMn+JHd1ZYeBF5McB54APlJVLwIkuYXR3xT+8+TblyQtR6pq2j1cYjAY1HA4nHYb\nknRNSXK0qgZL1fmNXElqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kN\nMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNWSs0E+yO8mJJHNJ\n9l2m5p4kx5McS/JQb31bks8kebbbf8tkWpckLdf6pQqSrAMOALuAeeBIktmqOt6r2QE8ANxRVeeS\n3NR7id8EfrGqHk/yJuDbE51AkjS2cc70dwJzVXWqqs4Dh4C7F9R8ADhQVecAquoFgCRvB9ZX1ePd\n+stV9crEupckLcs4ob8ZON3bnu/W+m4Fbk3ypSRPJdndW38pyb9P8gdJPtH9zUGSNAWTupG7HtgB\n3AncC/xakhu69XcDHwZ+CHgr8DMLD06yN8kwyfDs2bMTakmStNA4oX8G2Nrb3tKt9c0Ds1X1alU9\nD5xk9CYwDzzdXRq6ADwG/ODCH1BVB6tqUFWDmZmZlcwhSRrDOKF/BNiRZHuSDcAeYHZBzWOMzvJJ\nsonRZZ1T3bE3JLmY5D8GHEeSNBVLhn53hn4/cBh4Fnikqo4l2Z/krq7sMPBikuPAE8BHqurFqnqN\n0aWdzyb5KhDg19ZiEEnS0lJV0+7hEoPBoIbD4bTbkKRrSpKjVTVYqs5v5EpSQwx9SWqIoS9JDTH0\nJakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkOuum/kJjkL/K9p97ECm4A/mnYTV5gzt8GZrw03V9WS\n/2LlVRf616okw3G+Av3dxJnb4MzfXby8I0kNMfQlqSGG/uQcnHYDU+DMbXDm7yJe05ekhnimL0kN\nMfSXIcmNSR5P8lz368bL1N3X1TyX5L5F9s8meWbtO1691cyc5Pokv5vkfyQ5luSXrmz340uyO8mJ\nJHNJ9i2y/7okv93t/29Jbunte6BbP5HkfVey79VY6cxJdiU5muSr3a8/dqV7X6nV/Dl3+7cleTnJ\nh69UzxNXVT7GfAAfB/Z1z/cBH1uk5kZG/z/wjcDG7vnG3v6fAB4Cnpn2PGs9M3A98J6uZgPwReBv\nTHumRfpfB3wNeGvX55eBty+o+YfAr3bP9wC/3T1/e1d/HbC9e511055pjWe+Hfj+7vlfBs5Me561\nnrm3/1Hgd4APT3uelT4801+eu4Hf6J7/BvD+RWreBzxeVd+sqnPA48BugCRvAv4x8C+uQK+TsuKZ\nq+qVqnoCoKrOA78PbLkCPS/XTmCuqk51fR5iNHdf//fhUeDHk6RbP1RV36qq54G57vWudiueuar+\noKq+3q0fA743yXVXpOvVWc2fM0neDzzPaOZrlqG/PG+uqm90z/8P8OZFajYDp3vb890awC8Avwy8\nsmYdTt5qZwYgyQ3A3wY+uxZNrtKS/fdrquoC8H+B7xvz2KvRambu+7vA71fVt9aoz0la8czdCds/\nAf75FehzTa2fdgNXmyS/B/yFRXZ9tL9RVZVk7I8+JbkN+ItV9aGF1wmnba1m7r3+euBh4Feq6tTK\nutTVJsk7gI8B7512L1fAzwMPVtXL3Yn/NcvQX6Cq/vrl9iX5wyRvqapvJHkL8MIiZWeAO3vbW4DP\nAz8MDJL8T0a/7zcl+XxV3cmUreHMFx0EnquqfzWBdtfCGWBrb3tLt7ZYzXz3JvbngRfHPPZqtJqZ\nSbIF+BTw01X1tbVvdyJWM/NfAX4yyceBG4BvJ/nTqvq3a9/2hE37psK19AA+waU3NT++SM2NjK77\nbewezwM3Lqi5hWvnRu6qZmZ0/+LfAd8z7VleZ8b1jG4+b+fPbvC9Y0HNP+LSG3yPdM/fwaU3ck9x\nbdzIXc3MN3T1PzHtOa7UzAtqfp5r+Ebu1Bu4lh6Mrmd+FngO+L1esA2AX+/V/RyjG3pzwM8u8jrX\nUuiveGZGZ1IFPAs83T3+/rRnusycfxM4yejTHR/t1vYDd3XP38joUxtzwH8H3to79qPdcSe4Cj+d\nNOmZgX8K/Envz/Rp4KZpz7PWf86917imQ99v5EpSQ/z0jiQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9\nSWqIoS9JDTH0Jakh/x+EzKx7zkiI+gAAAABJRU5ErkJggg==\n",
       "text": [
        "<matplotlib.figure.Figure at 0x7f1179072f28>"
       ]
      }
     ],
     "prompt_number": 6
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "val_data = np.load('val_data.npy')\n",
      "val_label = np.load('val_label.npy')\n",
      "def predict(val_data, val_label):\n",
      "    loss = 0\n",
      "    num = len(val_data)\n",
      "    count = 0\n",
      "    for i in range(num):\n",
      "        z = np.dot(val_data[i], weight) + bias\n",
      "        a = sigmoid(z)\n",
      "        if (a<0.5 and val_label[i] == 0):\n",
      "            count += 1\n",
      "        elif (val_label[i] == 1):\n",
      "            count += 1\n",
      "        loss += compute_loss(a, val_label[i])\n",
      "    print (loss/num)\n",
      "    print (count/num)\n",
      "predict(val_data, val_label)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "0.172008175255\n",
        "0.9840319361277445\n"
       ]
      }
     ],
     "prompt_number": 7
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [],
     "language": "python",
     "metadata": {},
     "outputs": []
    }
   ],
   "metadata": {}
  }
 ]
}